{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from neo4j import GraphDatabase\n",
    "import pandas as pd\n",
    "from sklearn.linear_model import SGDClassifier\n",
    "from sklearn.metrics import f1_score\n",
    "from sklearn.multioutput import MultiOutputClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neo4j driver\n",
    "driver = GraphDatabase.driver('neo4j://localhost:7687', auth=('neo4j', 'letmein'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Agenda\n",
    "In this example, you will reproduce the protein role classification task from the original GraphSAGE article. The task is to classify protein roles in terms of their cellular function across various protein-protein interaction graphs (PPI). The dataset contains 22 PPI graphs, with each graph corresponding to a different human tissue. The average PPI graph contains 2373 nodes, with an average degree of 28.8. There are available predefined positional gene sets, motif gene sets, and immunological signatures for each protein in the network. Based on those features and their connections, you will predict the roles of proteins in the network. You will train both the classification and GraphSAGE model on 20 graphs and then average prediction F1 scores on two test graphs.\n",
    "## Graph model\n",
    "As mentioned, we are dealing with a protein-protein interaction network. This is a monopartite network, where nodes represent proteins and relationships represent their interactions.\n",
    "\n",
    "Additionally, the protein nodes have the predefined features stored as a property. The embeddings_all property contains all 50 features stored as a list of floats. I have also prepared the decoupled properties, where the embedding_x property holds a single feature and x ranges from 0 to 49. You will see later in the blog post why the decoupled properties are required. The protein nodes also contain a secondary label that could be either Train or Test. With the help of the secondary label, you can easily perform the train-test data split.\n",
    "\n",
    "## Classification using predefined features\n",
    "To get a baseline f1 score, you will first train the classification model using only the predefined features available for proteins. The code is identical to the code found in the official GraphSAGE repository, where they used the Stochastic Gradient Descent classifier model to train and predict protein roles. The only difference is that here you will be fetching the data from a Neo4j database instance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_train_data_query = \"\"\"\n",
    "MATCH (t:Train)\n",
    "RETURN t.class as class, t.embeddings_all as features\n",
    "\"\"\"\n",
    "\n",
    "raw_test_data_query = \"\"\"\n",
    "MATCH (t:Test)\n",
    "RETURN t.class as class, t.embeddings_all as features\n",
    "\"\"\"\n",
    "\n",
    "with driver.session() as session:\n",
    "    # Fetch training data\n",
    "    train_results = session.run(raw_train_data_query)\n",
    "    train_results_df = pd.DataFrame([dict(r) for r in train_results])\n",
    "    \n",
    "    # Fetch test data\n",
    "    test_results = session.run(raw_test_data_query)\n",
    "    test_results_df = pd.DataFrame([dict(r) for r in test_results])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "log = MultiOutputClassifier(SGDClassifier(loss=\"log\"), n_jobs=10)\n",
    "log.fit(train_results_df['features'].to_list(), train_results_df['class'].to_list())\n",
    "\n",
    "print(f1_score(test_results_df['class'].to_list(), \n",
    "               log.predict(test_results_df['features'].to_list()), average=\"micro\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before you can execute any graph algorithms, you have to project the in-memory graph via the Graph Loader component. You can use either native projection or cypher projection to load the in-memory graph. \n",
    "In this example, you will use the native projection feature to load the in-memory graph. To start, you will project the training data and store it as a named graph in the Graph Catalog. The current implementation of the GraphSAGE algorithm supports only node features that are of type Float. For this reason, you will include the decoupled node properties ranging from embedding_0 to embedding_49 in the graph projection instead of a single property embeddings_all, which holds all the node features in the form of a list of Floats. Additionally, you will treat the projected graph as undirected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "with driver.session() as session:\n",
    "    session.run(\"\"\"UNWIND range(0,49) as i\n",
    "                   WITH collect('embedding_' + toString(i)) as embeddings\n",
    "                   CALL gds.graph.create('train','Train',\n",
    "                    {INTERACTS:{orientation:'UNDIRECTED'}}, {nodeProperties:embeddings}) \n",
    "                   YIELD graphName, nodeCount, relationshipCount\n",
    "                   RETURN graphName, nodeCount, relationshipCount\"\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, you will train the GraphSAGE model. The model's hyper-parameter settings were mostly copied from the original paper. I have noticed that the lower learning-rate setting had the most impact on the downstream classification accuracy. Another import hyper-parameter is the samplingSizes parameter, where the size of the list determines the number of layers (defined as K parameter in the paper), and the values determine how many nodes will be sampled by the layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "with driver.session() as session:\n",
    "    session.run(\"\"\"\n",
    "        UNWIND range(0,49) as i\n",
    "        WITH collect('embedding_' + toString(i)) as embeddings\n",
    "        CALL gds.beta.graphSage.train('train',{\n",
    "          modelName:'proteinModel',\n",
    "          aggregator:'pool',\n",
    "          batchSize:512,\n",
    "          activationFunction:'relu',\n",
    "          epochs:10,\n",
    "          sampleSizes:[25,10],\n",
    "          learningRate:0.0000001,\n",
    "          embeddingDimension:256,\n",
    "          featureProperties:embeddings})\n",
    "        YIELD modelInfo\n",
    "        RETURN modelInfo\"\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training process took around 20 minutes on my laptop. After the training process finishes, the GraphSAGE model will be stored in the model catalog. You can now use this model to induce node embeddings on any projected graph with the same node properties used during the training. Before testing the downstream classification accuracy, you have to load the test data as an in-memory graph in the Graph Catalog."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "with driver.session() as session:\n",
    "    session.run(\"\"\"\n",
    "        UNWIND range(0,49) as i\n",
    "        WITH collect('embedding_' + toString(i)) as embeddings\n",
    "        CALL gds.graph.create('test','Test',{INTERACTS:{orientation:'UNDIRECTED'}}, \n",
    "          {nodeProperties:embeddings}) \n",
    "        YIELD graphName, nodeCount, relationshipCount\n",
    "        RETURN graphName, nodeCount, relationshipCount\"\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the GraphSAGE model trained and both the training and test data projected as an in-memory graph, you can go ahead and calculate the f1 score using the GraphSAGE embeddings in a downstream classification model. Remember, the GraphSAGE model has not observed the test data during the training phase."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graphsage_train_data = \"\"\"\n",
    "CALL gds.beta.graphSage.stream('train', {modelName:'proteinModel'})\n",
    "YIELD nodeId, embedding\n",
    "RETURN gds.util.asNode(nodeId).class as class, embedding as features\n",
    "\"\"\"\n",
    "\n",
    "graphsage_test_data = \"\"\"\n",
    "CALL gds.beta.graphSage.stream('test', {modelName:'proteinModel'})\n",
    "YIELD nodeId, embedding\n",
    "RETURN gds.util.asNode(nodeId).class as class, embedding as features\n",
    "\"\"\"\n",
    "with driver.session() as session:\n",
    "    # Fetch training data\n",
    "    train_results = session.run(graphsage_train_data)\n",
    "    train_results_df = pd.DataFrame([dict(t) for t in train_results])\n",
    "    # Fetch test data\n",
    "    test_results = session.run(graphsage_test_data)\n",
    "    test_results_df = pd.DataFrame([dict(t) for t in test_results])\n",
    "\n",
    "# Train the SGD classifier\n",
    "log = MultiOutputClassifier(SGDClassifier(loss=\"log\"), n_jobs=10)\n",
    "log.fit(train_results_df['features'].to_list(), train_results_df['class'].to_list())\n",
    "\n",
    "# Calculate the f1 score on test data\n",
    "print(f1_score(test_results_df['class'].to_list(), \n",
    "               log.predict(test_results_df['features'].to_list()), average=\"micro\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nlp",
   "language": "python",
   "name": "nlp"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
